0. Import libraries and load data

In [1]:
from sklearn.datasets import load_files       
from keras.utils import np_utils
import numpy as np
from glob import glob
Using TensorFlow backend.
In [2]:
# define function to load train, test, and validation datasets
def load_dataset(path):
    data = load_files(path)
    paths = np.array(data['filenames'])
    targets = np_utils.to_categorical(np.array(data['target']))
    return paths, targets
In [3]:
all_files, all_targets = load_dataset('../data/all_images')
In [4]:
print('There are a total of %d labeled images in your dataset.' % len(all_files))
There are a total of 10026 labeled images in your dataset.

1. Image preprocessing

In [5]:
# Visualize what the data looks like
import cv2
import matplotlib.pyplot as plt
%matplotlib inline

def visualize_img(img_path, ax):
    img = cv2.imread(img_path)
    ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    
fig = plt.figure(figsize=(180 ,160))
for i in range(36):
    ax = fig.add_subplot(6, 6, i + 1, xticks=[], yticks=[])
    visualize_img(all_files[i], ax)
In [7]:
# split the dataset into train, valid, and test 
from sklearn.model_selection import train_test_split
X = all_files
y = all_targets
train_files, valid_files, train_targets, valid_targets = train_test_split(X, y, test_size=0.1, random_state=42)
In [8]:
from keras.preprocessing import image                  
from tqdm import tqdm

def path_to_tensor(img_path):
    # loads RGB image as PIL.Image.Image type
    img = image.load_img(img_path, target_size=(224, 224))
    # convert PIL.Image.Image type to 3D tensor with shape (224, 224, 3)
    x = image.img_to_array(img)
    # convert 3D tensor to 4D tensor with shape (1, 224, 224, 3) and return 4D tensor
    return np.expand_dims(x, axis=0)

def paths_to_tensor(img_paths):
    list_of_tensors = [path_to_tensor(img_path) for img_path in tqdm(img_paths)]
    return np.vstack(list_of_tensors)

print("Shape of the tensor:")
print(path_to_tensor(all_files[0]).shape)
print("Shape of the 4D tensor:")
print(paths_to_tensor(train_files).shape)
  0%|          | 8/9023 [00:00<02:14, 66.95it/s]
Shape of the tensor:
(1, 224, 224, 3)
Shape of the 4D tensor:
100%|██████████| 9023/9023 [02:03<00:00, 72.95it/s]
(9023, 224, 224, 3)
In [9]:
print(len(train_files), len(valid_files))
9023 1003
In [10]:
from keras.applications.resnet50 import preprocess_input

train_tensors = preprocess_input(paths_to_tensor(train_files))
valid_tensors = preprocess_input(paths_to_tensor(valid_files))
# test_tensors = preprocess_input(paths_to_tensor(test_files))
100%|██████████| 9023/9023 [02:02<00:00, 73.71it/s]
100%|██████████| 1003/1003 [00:13<00:00, 76.39it/s]
In [11]:
from keras.applications.resnet50 import ResNet50

image_dim = (224, 224, 3)

base_model = ResNet50(include_top=False, input_shape=image_dim, weights='imagenet')
In [12]:
train_Resnet = base_model.predict(train_tensors)
valid_Resnet = base_model.predict(valid_tensors)
# test_Resnet = base_model.predict(test_tensors)
In [13]:
print(train_Resnet.shape)
(9023, 1, 1, 2048)

2. Build the CNN model

In [14]:
from keras.applications import ResNet50, VGG19, InceptionV3
from keras.preprocessing import image
from keras.applications.resnet50 import preprocess_input, decode_predictions
import numpy as np

for layer in base_model.layers:
    layer.trainable = False
                      
# Train several last layers in base model
for layer in base_model.layers[-22:]:
    layer.trainable = True
    
base_model.summary()
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_1 (InputLayer)             (None, 224, 224, 3)   0                                            
____________________________________________________________________________________________________
zero_padding2d_1 (ZeroPadding2D) (None, 230, 230, 3)   0                                            
____________________________________________________________________________________________________
conv1 (Conv2D)                   (None, 112, 112, 64)  9472                                         
____________________________________________________________________________________________________
bn_conv1 (BatchNormalization)    (None, 112, 112, 64)  256                                          
____________________________________________________________________________________________________
activation_1 (Activation)        (None, 112, 112, 64)  0                                            
____________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)   (None, 55, 55, 64)    0                                            
____________________________________________________________________________________________________
res2a_branch2a (Conv2D)          (None, 55, 55, 64)    4160                                         
____________________________________________________________________________________________________
bn2a_branch2a (BatchNormalizatio (None, 55, 55, 64)    256                                          
____________________________________________________________________________________________________
activation_2 (Activation)        (None, 55, 55, 64)    0                                            
____________________________________________________________________________________________________
res2a_branch2b (Conv2D)          (None, 55, 55, 64)    36928                                        
____________________________________________________________________________________________________
bn2a_branch2b (BatchNormalizatio (None, 55, 55, 64)    256                                          
____________________________________________________________________________________________________
activation_3 (Activation)        (None, 55, 55, 64)    0                                            
____________________________________________________________________________________________________
res2a_branch2c (Conv2D)          (None, 55, 55, 256)   16640                                        
____________________________________________________________________________________________________
res2a_branch1 (Conv2D)           (None, 55, 55, 256)   16640                                        
____________________________________________________________________________________________________
bn2a_branch2c (BatchNormalizatio (None, 55, 55, 256)   1024                                         
____________________________________________________________________________________________________
bn2a_branch1 (BatchNormalization (None, 55, 55, 256)   1024                                         
____________________________________________________________________________________________________
add_1 (Add)                      (None, 55, 55, 256)   0                                            
____________________________________________________________________________________________________
activation_4 (Activation)        (None, 55, 55, 256)   0                                            
____________________________________________________________________________________________________
res2b_branch2a (Conv2D)          (None, 55, 55, 64)    16448                                        
____________________________________________________________________________________________________
bn2b_branch2a (BatchNormalizatio (None, 55, 55, 64)    256                                          
____________________________________________________________________________________________________
activation_5 (Activation)        (None, 55, 55, 64)    0                                            
____________________________________________________________________________________________________
res2b_branch2b (Conv2D)          (None, 55, 55, 64)    36928                                        
____________________________________________________________________________________________________
bn2b_branch2b (BatchNormalizatio (None, 55, 55, 64)    256                                          
____________________________________________________________________________________________________
activation_6 (Activation)        (None, 55, 55, 64)    0                                            
____________________________________________________________________________________________________
res2b_branch2c (Conv2D)          (None, 55, 55, 256)   16640                                        
____________________________________________________________________________________________________
bn2b_branch2c (BatchNormalizatio (None, 55, 55, 256)   1024                                         
____________________________________________________________________________________________________
add_2 (Add)                      (None, 55, 55, 256)   0                                            
____________________________________________________________________________________________________
activation_7 (Activation)        (None, 55, 55, 256)   0                                            
____________________________________________________________________________________________________
res2c_branch2a (Conv2D)          (None, 55, 55, 64)    16448                                        
____________________________________________________________________________________________________
bn2c_branch2a (BatchNormalizatio (None, 55, 55, 64)    256                                          
____________________________________________________________________________________________________
activation_8 (Activation)        (None, 55, 55, 64)    0                                            
____________________________________________________________________________________________________
res2c_branch2b (Conv2D)          (None, 55, 55, 64)    36928                                        
____________________________________________________________________________________________________
bn2c_branch2b (BatchNormalizatio (None, 55, 55, 64)    256                                          
____________________________________________________________________________________________________
activation_9 (Activation)        (None, 55, 55, 64)    0                                            
____________________________________________________________________________________________________
res2c_branch2c (Conv2D)          (None, 55, 55, 256)   16640                                        
____________________________________________________________________________________________________
bn2c_branch2c (BatchNormalizatio (None, 55, 55, 256)   1024                                         
____________________________________________________________________________________________________
add_3 (Add)                      (None, 55, 55, 256)   0                                            
____________________________________________________________________________________________________
activation_10 (Activation)       (None, 55, 55, 256)   0                                            
____________________________________________________________________________________________________
res3a_branch2a (Conv2D)          (None, 28, 28, 128)   32896                                        
____________________________________________________________________________________________________
bn3a_branch2a (BatchNormalizatio (None, 28, 28, 128)   512                                          
____________________________________________________________________________________________________
activation_11 (Activation)       (None, 28, 28, 128)   0                                            
____________________________________________________________________________________________________
res3a_branch2b (Conv2D)          (None, 28, 28, 128)   147584                                       
____________________________________________________________________________________________________
bn3a_branch2b (BatchNormalizatio (None, 28, 28, 128)   512                                          
____________________________________________________________________________________________________
activation_12 (Activation)       (None, 28, 28, 128)   0                                            
____________________________________________________________________________________________________
res3a_branch2c (Conv2D)          (None, 28, 28, 512)   66048                                        
____________________________________________________________________________________________________
res3a_branch1 (Conv2D)           (None, 28, 28, 512)   131584                                       
____________________________________________________________________________________________________
bn3a_branch2c (BatchNormalizatio (None, 28, 28, 512)   2048                                         
____________________________________________________________________________________________________
bn3a_branch1 (BatchNormalization (None, 28, 28, 512)   2048                                         
____________________________________________________________________________________________________
add_4 (Add)                      (None, 28, 28, 512)   0                                            
____________________________________________________________________________________________________
activation_13 (Activation)       (None, 28, 28, 512)   0                                            
____________________________________________________________________________________________________
res3b_branch2a (Conv2D)          (None, 28, 28, 128)   65664                                        
____________________________________________________________________________________________________
bn3b_branch2a (BatchNormalizatio (None, 28, 28, 128)   512                                          
____________________________________________________________________________________________________
activation_14 (Activation)       (None, 28, 28, 128)   0                                            
____________________________________________________________________________________________________
res3b_branch2b (Conv2D)          (None, 28, 28, 128)   147584                                       
____________________________________________________________________________________________________
bn3b_branch2b (BatchNormalizatio (None, 28, 28, 128)   512                                          
____________________________________________________________________________________________________
activation_15 (Activation)       (None, 28, 28, 128)   0                                            
____________________________________________________________________________________________________
res3b_branch2c (Conv2D)          (None, 28, 28, 512)   66048                                        
____________________________________________________________________________________________________
bn3b_branch2c (BatchNormalizatio (None, 28, 28, 512)   2048                                         
____________________________________________________________________________________________________
add_5 (Add)                      (None, 28, 28, 512)   0                                            
____________________________________________________________________________________________________
activation_16 (Activation)       (None, 28, 28, 512)   0                                            
____________________________________________________________________________________________________
res3c_branch2a (Conv2D)          (None, 28, 28, 128)   65664                                        
____________________________________________________________________________________________________
bn3c_branch2a (BatchNormalizatio (None, 28, 28, 128)   512                                          
____________________________________________________________________________________________________
activation_17 (Activation)       (None, 28, 28, 128)   0                                            
____________________________________________________________________________________________________
res3c_branch2b (Conv2D)          (None, 28, 28, 128)   147584                                       
____________________________________________________________________________________________________
bn3c_branch2b (BatchNormalizatio (None, 28, 28, 128)   512                                          
____________________________________________________________________________________________________
activation_18 (Activation)       (None, 28, 28, 128)   0                                            
____________________________________________________________________________________________________
res3c_branch2c (Conv2D)          (None, 28, 28, 512)   66048                                        
____________________________________________________________________________________________________
bn3c_branch2c (BatchNormalizatio (None, 28, 28, 512)   2048                                         
____________________________________________________________________________________________________
add_6 (Add)                      (None, 28, 28, 512)   0                                            
____________________________________________________________________________________________________
activation_19 (Activation)       (None, 28, 28, 512)   0                                            
____________________________________________________________________________________________________
res3d_branch2a (Conv2D)          (None, 28, 28, 128)   65664                                        
____________________________________________________________________________________________________
bn3d_branch2a (BatchNormalizatio (None, 28, 28, 128)   512                                          
____________________________________________________________________________________________________
activation_20 (Activation)       (None, 28, 28, 128)   0                                            
____________________________________________________________________________________________________
res3d_branch2b (Conv2D)          (None, 28, 28, 128)   147584                                       
____________________________________________________________________________________________________
bn3d_branch2b (BatchNormalizatio (None, 28, 28, 128)   512                                          
____________________________________________________________________________________________________
activation_21 (Activation)       (None, 28, 28, 128)   0                                            
____________________________________________________________________________________________________
res3d_branch2c (Conv2D)          (None, 28, 28, 512)   66048                                        
____________________________________________________________________________________________________
bn3d_branch2c (BatchNormalizatio (None, 28, 28, 512)   2048                                         
____________________________________________________________________________________________________
add_7 (Add)                      (None, 28, 28, 512)   0                                            
____________________________________________________________________________________________________
activation_22 (Activation)       (None, 28, 28, 512)   0                                            
____________________________________________________________________________________________________
res4a_branch2a (Conv2D)          (None, 14, 14, 256)   131328                                       
____________________________________________________________________________________________________
bn4a_branch2a (BatchNormalizatio (None, 14, 14, 256)   1024                                         
____________________________________________________________________________________________________
activation_23 (Activation)       (None, 14, 14, 256)   0                                            
____________________________________________________________________________________________________
res4a_branch2b (Conv2D)          (None, 14, 14, 256)   590080                                       
____________________________________________________________________________________________________
bn4a_branch2b (BatchNormalizatio (None, 14, 14, 256)   1024                                         
____________________________________________________________________________________________________
activation_24 (Activation)       (None, 14, 14, 256)   0                                            
____________________________________________________________________________________________________
res4a_branch2c (Conv2D)          (None, 14, 14, 1024)  263168                                       
____________________________________________________________________________________________________
res4a_branch1 (Conv2D)           (None, 14, 14, 1024)  525312                                       
____________________________________________________________________________________________________
bn4a_branch2c (BatchNormalizatio (None, 14, 14, 1024)  4096                                         
____________________________________________________________________________________________________
bn4a_branch1 (BatchNormalization (None, 14, 14, 1024)  4096                                         
____________________________________________________________________________________________________
add_8 (Add)                      (None, 14, 14, 1024)  0                                            
____________________________________________________________________________________________________
activation_25 (Activation)       (None, 14, 14, 1024)  0                                            
____________________________________________________________________________________________________
res4b_branch2a (Conv2D)          (None, 14, 14, 256)   262400                                       
____________________________________________________________________________________________________
bn4b_branch2a (BatchNormalizatio (None, 14, 14, 256)   1024                                         
____________________________________________________________________________________________________
activation_26 (Activation)       (None, 14, 14, 256)   0                                            
____________________________________________________________________________________________________
res4b_branch2b (Conv2D)          (None, 14, 14, 256)   590080                                       
____________________________________________________________________________________________________
bn4b_branch2b (BatchNormalizatio (None, 14, 14, 256)   1024                                         
____________________________________________________________________________________________________
activation_27 (Activation)       (None, 14, 14, 256)   0                                            
____________________________________________________________________________________________________
res4b_branch2c (Conv2D)          (None, 14, 14, 1024)  263168                                       
____________________________________________________________________________________________________
bn4b_branch2c (BatchNormalizatio (None, 14, 14, 1024)  4096                                         
____________________________________________________________________________________________________
add_9 (Add)                      (None, 14, 14, 1024)  0                                            
____________________________________________________________________________________________________
activation_28 (Activation)       (None, 14, 14, 1024)  0                                            
____________________________________________________________________________________________________
res4c_branch2a (Conv2D)          (None, 14, 14, 256)   262400                                       
____________________________________________________________________________________________________
bn4c_branch2a (BatchNormalizatio (None, 14, 14, 256)   1024                                         
____________________________________________________________________________________________________
activation_29 (Activation)       (None, 14, 14, 256)   0                                            
____________________________________________________________________________________________________
res4c_branch2b (Conv2D)          (None, 14, 14, 256)   590080                                       
____________________________________________________________________________________________________
bn4c_branch2b (BatchNormalizatio (None, 14, 14, 256)   1024                                         
____________________________________________________________________________________________________
activation_30 (Activation)       (None, 14, 14, 256)   0                                            
____________________________________________________________________________________________________
res4c_branch2c (Conv2D)          (None, 14, 14, 1024)  263168                                       
____________________________________________________________________________________________________
bn4c_branch2c (BatchNormalizatio (None, 14, 14, 1024)  4096                                         
____________________________________________________________________________________________________
add_10 (Add)                     (None, 14, 14, 1024)  0                                            
____________________________________________________________________________________________________
activation_31 (Activation)       (None, 14, 14, 1024)  0                                            
____________________________________________________________________________________________________
res4d_branch2a (Conv2D)          (None, 14, 14, 256)   262400                                       
____________________________________________________________________________________________________
bn4d_branch2a (BatchNormalizatio (None, 14, 14, 256)   1024                                         
____________________________________________________________________________________________________
activation_32 (Activation)       (None, 14, 14, 256)   0                                            
____________________________________________________________________________________________________
res4d_branch2b (Conv2D)          (None, 14, 14, 256)   590080                                       
____________________________________________________________________________________________________
bn4d_branch2b (BatchNormalizatio (None, 14, 14, 256)   1024                                         
____________________________________________________________________________________________________
activation_33 (Activation)       (None, 14, 14, 256)   0                                            
____________________________________________________________________________________________________
res4d_branch2c (Conv2D)          (None, 14, 14, 1024)  263168                                       
____________________________________________________________________________________________________
bn4d_branch2c (BatchNormalizatio (None, 14, 14, 1024)  4096                                         
____________________________________________________________________________________________________
add_11 (Add)                     (None, 14, 14, 1024)  0                                            
____________________________________________________________________________________________________
activation_34 (Activation)       (None, 14, 14, 1024)  0                                            
____________________________________________________________________________________________________
res4e_branch2a (Conv2D)          (None, 14, 14, 256)   262400                                       
____________________________________________________________________________________________________
bn4e_branch2a (BatchNormalizatio (None, 14, 14, 256)   1024                                         
____________________________________________________________________________________________________
activation_35 (Activation)       (None, 14, 14, 256)   0                                            
____________________________________________________________________________________________________
res4e_branch2b (Conv2D)          (None, 14, 14, 256)   590080                                       
____________________________________________________________________________________________________
bn4e_branch2b (BatchNormalizatio (None, 14, 14, 256)   1024                                         
____________________________________________________________________________________________________
activation_36 (Activation)       (None, 14, 14, 256)   0                                            
____________________________________________________________________________________________________
res4e_branch2c (Conv2D)          (None, 14, 14, 1024)  263168                                       
____________________________________________________________________________________________________
bn4e_branch2c (BatchNormalizatio (None, 14, 14, 1024)  4096                                         
____________________________________________________________________________________________________
add_12 (Add)                     (None, 14, 14, 1024)  0                                            
____________________________________________________________________________________________________
activation_37 (Activation)       (None, 14, 14, 1024)  0                                            
____________________________________________________________________________________________________
res4f_branch2a (Conv2D)          (None, 14, 14, 256)   262400                                       
____________________________________________________________________________________________________
bn4f_branch2a (BatchNormalizatio (None, 14, 14, 256)   1024                                         
____________________________________________________________________________________________________
activation_38 (Activation)       (None, 14, 14, 256)   0                                            
____________________________________________________________________________________________________
res4f_branch2b (Conv2D)          (None, 14, 14, 256)   590080                                       
____________________________________________________________________________________________________
bn4f_branch2b (BatchNormalizatio (None, 14, 14, 256)   1024                                         
____________________________________________________________________________________________________
activation_39 (Activation)       (None, 14, 14, 256)   0                                            
____________________________________________________________________________________________________
res4f_branch2c (Conv2D)          (None, 14, 14, 1024)  263168                                       
____________________________________________________________________________________________________
bn4f_branch2c (BatchNormalizatio (None, 14, 14, 1024)  4096                                         
____________________________________________________________________________________________________
add_13 (Add)                     (None, 14, 14, 1024)  0                                            
____________________________________________________________________________________________________
activation_40 (Activation)       (None, 14, 14, 1024)  0                                            
____________________________________________________________________________________________________
res5a_branch2a (Conv2D)          (None, 7, 7, 512)     524800                                       
____________________________________________________________________________________________________
bn5a_branch2a (BatchNormalizatio (None, 7, 7, 512)     2048                                         
____________________________________________________________________________________________________
activation_41 (Activation)       (None, 7, 7, 512)     0                                            
____________________________________________________________________________________________________
res5a_branch2b (Conv2D)          (None, 7, 7, 512)     2359808                                      
____________________________________________________________________________________________________
bn5a_branch2b (BatchNormalizatio (None, 7, 7, 512)     2048                                         
____________________________________________________________________________________________________
activation_42 (Activation)       (None, 7, 7, 512)     0                                            
____________________________________________________________________________________________________
res5a_branch2c (Conv2D)          (None, 7, 7, 2048)    1050624                                      
____________________________________________________________________________________________________
res5a_branch1 (Conv2D)           (None, 7, 7, 2048)    2099200                                      
____________________________________________________________________________________________________
bn5a_branch2c (BatchNormalizatio (None, 7, 7, 2048)    8192                                         
____________________________________________________________________________________________________
bn5a_branch1 (BatchNormalization (None, 7, 7, 2048)    8192                                         
____________________________________________________________________________________________________
add_14 (Add)                     (None, 7, 7, 2048)    0                                            
____________________________________________________________________________________________________
activation_43 (Activation)       (None, 7, 7, 2048)    0                                            
____________________________________________________________________________________________________
res5b_branch2a (Conv2D)          (None, 7, 7, 512)     1049088                                      
____________________________________________________________________________________________________
bn5b_branch2a (BatchNormalizatio (None, 7, 7, 512)     2048                                         
____________________________________________________________________________________________________
activation_44 (Activation)       (None, 7, 7, 512)     0                                            
____________________________________________________________________________________________________
res5b_branch2b (Conv2D)          (None, 7, 7, 512)     2359808                                      
____________________________________________________________________________________________________
bn5b_branch2b (BatchNormalizatio (None, 7, 7, 512)     2048                                         
____________________________________________________________________________________________________
activation_45 (Activation)       (None, 7, 7, 512)     0                                            
____________________________________________________________________________________________________
res5b_branch2c (Conv2D)          (None, 7, 7, 2048)    1050624                                      
____________________________________________________________________________________________________
bn5b_branch2c (BatchNormalizatio (None, 7, 7, 2048)    8192                                         
____________________________________________________________________________________________________
add_15 (Add)                     (None, 7, 7, 2048)    0                                            
____________________________________________________________________________________________________
activation_46 (Activation)       (None, 7, 7, 2048)    0                                            
____________________________________________________________________________________________________
res5c_branch2a (Conv2D)          (None, 7, 7, 512)     1049088                                      
____________________________________________________________________________________________________
bn5c_branch2a (BatchNormalizatio (None, 7, 7, 512)     2048                                         
____________________________________________________________________________________________________
activation_47 (Activation)       (None, 7, 7, 512)     0                                            
____________________________________________________________________________________________________
res5c_branch2b (Conv2D)          (None, 7, 7, 512)     2359808                                      
____________________________________________________________________________________________________
bn5c_branch2b (BatchNormalizatio (None, 7, 7, 512)     2048                                         
____________________________________________________________________________________________________
activation_48 (Activation)       (None, 7, 7, 512)     0                                            
____________________________________________________________________________________________________
res5c_branch2c (Conv2D)          (None, 7, 7, 2048)    1050624                                      
____________________________________________________________________________________________________
bn5c_branch2c (BatchNormalizatio (None, 7, 7, 2048)    8192                                         
____________________________________________________________________________________________________
add_16 (Add)                     (None, 7, 7, 2048)    0                                            
____________________________________________________________________________________________________
activation_49 (Activation)       (None, 7, 7, 2048)    0                                            
____________________________________________________________________________________________________
avg_pool (AveragePooling2D)      (None, 1, 1, 2048)    0                                            
====================================================================================================
Total params: 23,587,712.0
Trainable params: 8,931,328.0
Non-trainable params: 14,656,384.0
____________________________________________________________________________________________________
In [15]:
## from keras.layers import Dense, GlobalAveragePooling2D
from keras.models import Sequential, Model
from keras.layers.normalization import BatchNormalization
from keras.layers import Dropout, Flatten, Dense, GlobalAveragePooling2D

# add a global spatial average pooling layer
x = base_model.output
#x = Flatten()(x)
x = GlobalAveragePooling2D()(x)
x = BatchNormalization()(x)
x = Dense(512, activation="relu")(x)
x = Dropout(0.2)(x)
x = BatchNormalization()(x)
x = Dense(512, activation="relu")(x)
x = Dropout(0.5)(x)
predictions = Dense(2, activation='softmax')(x)
# this is the model we will train
model = Model(inputs=base_model.input, outputs=predictions)
model.summary()
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_1 (InputLayer)             (None, 224, 224, 3)   0                                            
____________________________________________________________________________________________________
zero_padding2d_1 (ZeroPadding2D) (None, 230, 230, 3)   0                                            
____________________________________________________________________________________________________
conv1 (Conv2D)                   (None, 112, 112, 64)  9472                                         
____________________________________________________________________________________________________
bn_conv1 (BatchNormalization)    (None, 112, 112, 64)  256                                          
____________________________________________________________________________________________________
activation_1 (Activation)        (None, 112, 112, 64)  0                                            
____________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)   (None, 55, 55, 64)    0                                            
____________________________________________________________________________________________________
res2a_branch2a (Conv2D)          (None, 55, 55, 64)    4160                                         
____________________________________________________________________________________________________
bn2a_branch2a (BatchNormalizatio (None, 55, 55, 64)    256                                          
____________________________________________________________________________________________________
activation_2 (Activation)        (None, 55, 55, 64)    0                                            
____________________________________________________________________________________________________
res2a_branch2b (Conv2D)          (None, 55, 55, 64)    36928                                        
____________________________________________________________________________________________________
bn2a_branch2b (BatchNormalizatio (None, 55, 55, 64)    256                                          
____________________________________________________________________________________________________
activation_3 (Activation)        (None, 55, 55, 64)    0                                            
____________________________________________________________________________________________________
res2a_branch2c (Conv2D)          (None, 55, 55, 256)   16640                                        
____________________________________________________________________________________________________
res2a_branch1 (Conv2D)           (None, 55, 55, 256)   16640                                        
____________________________________________________________________________________________________
bn2a_branch2c (BatchNormalizatio (None, 55, 55, 256)   1024                                         
____________________________________________________________________________________________________
bn2a_branch1 (BatchNormalization (None, 55, 55, 256)   1024                                         
____________________________________________________________________________________________________
add_1 (Add)                      (None, 55, 55, 256)   0                                            
____________________________________________________________________________________________________
activation_4 (Activation)        (None, 55, 55, 256)   0                                            
____________________________________________________________________________________________________
res2b_branch2a (Conv2D)          (None, 55, 55, 64)    16448                                        
____________________________________________________________________________________________________
bn2b_branch2a (BatchNormalizatio (None, 55, 55, 64)    256                                          
____________________________________________________________________________________________________
activation_5 (Activation)        (None, 55, 55, 64)    0                                            
____________________________________________________________________________________________________
res2b_branch2b (Conv2D)          (None, 55, 55, 64)    36928                                        
____________________________________________________________________________________________________
bn2b_branch2b (BatchNormalizatio (None, 55, 55, 64)    256                                          
____________________________________________________________________________________________________
activation_6 (Activation)        (None, 55, 55, 64)    0                                            
____________________________________________________________________________________________________
res2b_branch2c (Conv2D)          (None, 55, 55, 256)   16640                                        
____________________________________________________________________________________________________
bn2b_branch2c (BatchNormalizatio (None, 55, 55, 256)   1024                                         
____________________________________________________________________________________________________
add_2 (Add)                      (None, 55, 55, 256)   0                                            
____________________________________________________________________________________________________
activation_7 (Activation)        (None, 55, 55, 256)   0                                            
____________________________________________________________________________________________________
res2c_branch2a (Conv2D)          (None, 55, 55, 64)    16448                                        
____________________________________________________________________________________________________
bn2c_branch2a (BatchNormalizatio (None, 55, 55, 64)    256                                          
____________________________________________________________________________________________________
activation_8 (Activation)        (None, 55, 55, 64)    0                                            
____________________________________________________________________________________________________
res2c_branch2b (Conv2D)          (None, 55, 55, 64)    36928                                        
____________________________________________________________________________________________________
bn2c_branch2b (BatchNormalizatio (None, 55, 55, 64)    256                                          
____________________________________________________________________________________________________
activation_9 (Activation)        (None, 55, 55, 64)    0                                            
____________________________________________________________________________________________________
res2c_branch2c (Conv2D)          (None, 55, 55, 256)   16640                                        
____________________________________________________________________________________________________
bn2c_branch2c (BatchNormalizatio (None, 55, 55, 256)   1024                                         
____________________________________________________________________________________________________
add_3 (Add)                      (None, 55, 55, 256)   0                                            
____________________________________________________________________________________________________
activation_10 (Activation)       (None, 55, 55, 256)   0                                            
____________________________________________________________________________________________________
res3a_branch2a (Conv2D)          (None, 28, 28, 128)   32896                                        
____________________________________________________________________________________________________
bn3a_branch2a (BatchNormalizatio (None, 28, 28, 128)   512                                          
____________________________________________________________________________________________________
activation_11 (Activation)       (None, 28, 28, 128)   0                                            
____________________________________________________________________________________________________
res3a_branch2b (Conv2D)          (None, 28, 28, 128)   147584                                       
____________________________________________________________________________________________________
bn3a_branch2b (BatchNormalizatio (None, 28, 28, 128)   512                                          
____________________________________________________________________________________________________
activation_12 (Activation)       (None, 28, 28, 128)   0                                            
____________________________________________________________________________________________________
res3a_branch2c (Conv2D)          (None, 28, 28, 512)   66048                                        
____________________________________________________________________________________________________
res3a_branch1 (Conv2D)           (None, 28, 28, 512)   131584                                       
____________________________________________________________________________________________________
bn3a_branch2c (BatchNormalizatio (None, 28, 28, 512)   2048                                         
____________________________________________________________________________________________________
bn3a_branch1 (BatchNormalization (None, 28, 28, 512)   2048                                         
____________________________________________________________________________________________________
add_4 (Add)                      (None, 28, 28, 512)   0                                            
____________________________________________________________________________________________________
activation_13 (Activation)       (None, 28, 28, 512)   0                                            
____________________________________________________________________________________________________
res3b_branch2a (Conv2D)          (None, 28, 28, 128)   65664                                        
____________________________________________________________________________________________________
bn3b_branch2a (BatchNormalizatio (None, 28, 28, 128)   512                                          
____________________________________________________________________________________________________
activation_14 (Activation)       (None, 28, 28, 128)   0                                            
____________________________________________________________________________________________________
res3b_branch2b (Conv2D)          (None, 28, 28, 128)   147584                                       
____________________________________________________________________________________________________
bn3b_branch2b (BatchNormalizatio (None, 28, 28, 128)   512                                          
____________________________________________________________________________________________________
activation_15 (Activation)       (None, 28, 28, 128)   0                                            
____________________________________________________________________________________________________
res3b_branch2c (Conv2D)          (None, 28, 28, 512)   66048                                        
____________________________________________________________________________________________________
bn3b_branch2c (BatchNormalizatio (None, 28, 28, 512)   2048                                         
____________________________________________________________________________________________________
add_5 (Add)                      (None, 28, 28, 512)   0                                            
____________________________________________________________________________________________________
activation_16 (Activation)       (None, 28, 28, 512)   0                                            
____________________________________________________________________________________________________
res3c_branch2a (Conv2D)          (None, 28, 28, 128)   65664                                        
____________________________________________________________________________________________________
bn3c_branch2a (BatchNormalizatio (None, 28, 28, 128)   512                                          
____________________________________________________________________________________________________
activation_17 (Activation)       (None, 28, 28, 128)   0                                            
____________________________________________________________________________________________________
res3c_branch2b (Conv2D)          (None, 28, 28, 128)   147584                                       
____________________________________________________________________________________________________
bn3c_branch2b (BatchNormalizatio (None, 28, 28, 128)   512                                          
____________________________________________________________________________________________________
activation_18 (Activation)       (None, 28, 28, 128)   0                                            
____________________________________________________________________________________________________
res3c_branch2c (Conv2D)          (None, 28, 28, 512)   66048                                        
____________________________________________________________________________________________________
bn3c_branch2c (BatchNormalizatio (None, 28, 28, 512)   2048                                         
____________________________________________________________________________________________________
add_6 (Add)                      (None, 28, 28, 512)   0                                            
____________________________________________________________________________________________________
activation_19 (Activation)       (None, 28, 28, 512)   0                                            
____________________________________________________________________________________________________
res3d_branch2a (Conv2D)          (None, 28, 28, 128)   65664                                        
____________________________________________________________________________________________________
bn3d_branch2a (BatchNormalizatio (None, 28, 28, 128)   512                                          
____________________________________________________________________________________________________
activation_20 (Activation)       (None, 28, 28, 128)   0                                            
____________________________________________________________________________________________________
res3d_branch2b (Conv2D)          (None, 28, 28, 128)   147584                                       
____________________________________________________________________________________________________
bn3d_branch2b (BatchNormalizatio (None, 28, 28, 128)   512                                          
____________________________________________________________________________________________________
activation_21 (Activation)       (None, 28, 28, 128)   0                                            
____________________________________________________________________________________________________
res3d_branch2c (Conv2D)          (None, 28, 28, 512)   66048                                        
____________________________________________________________________________________________________
bn3d_branch2c (BatchNormalizatio (None, 28, 28, 512)   2048                                         
____________________________________________________________________________________________________
add_7 (Add)                      (None, 28, 28, 512)   0                                            
____________________________________________________________________________________________________
activation_22 (Activation)       (None, 28, 28, 512)   0                                            
____________________________________________________________________________________________________
res4a_branch2a (Conv2D)          (None, 14, 14, 256)   131328                                       
____________________________________________________________________________________________________
bn4a_branch2a (BatchNormalizatio (None, 14, 14, 256)   1024                                         
____________________________________________________________________________________________________
activation_23 (Activation)       (None, 14, 14, 256)   0                                            
____________________________________________________________________________________________________
res4a_branch2b (Conv2D)          (None, 14, 14, 256)   590080                                       
____________________________________________________________________________________________________
bn4a_branch2b (BatchNormalizatio (None, 14, 14, 256)   1024                                         
____________________________________________________________________________________________________
activation_24 (Activation)       (None, 14, 14, 256)   0                                            
____________________________________________________________________________________________________
res4a_branch2c (Conv2D)          (None, 14, 14, 1024)  263168                                       
____________________________________________________________________________________________________
res4a_branch1 (Conv2D)           (None, 14, 14, 1024)  525312                                       
____________________________________________________________________________________________________
bn4a_branch2c (BatchNormalizatio (None, 14, 14, 1024)  4096                                         
____________________________________________________________________________________________________
bn4a_branch1 (BatchNormalization (None, 14, 14, 1024)  4096                                         
____________________________________________________________________________________________________
add_8 (Add)                      (None, 14, 14, 1024)  0                                            
____________________________________________________________________________________________________
activation_25 (Activation)       (None, 14, 14, 1024)  0                                            
____________________________________________________________________________________________________
res4b_branch2a (Conv2D)          (None, 14, 14, 256)   262400                                       
____________________________________________________________________________________________________
bn4b_branch2a (BatchNormalizatio (None, 14, 14, 256)   1024                                         
____________________________________________________________________________________________________
activation_26 (Activation)       (None, 14, 14, 256)   0                                            
____________________________________________________________________________________________________
res4b_branch2b (Conv2D)          (None, 14, 14, 256)   590080                                       
____________________________________________________________________________________________________
bn4b_branch2b (BatchNormalizatio (None, 14, 14, 256)   1024                                         
____________________________________________________________________________________________________
activation_27 (Activation)       (None, 14, 14, 256)   0                                            
____________________________________________________________________________________________________
res4b_branch2c (Conv2D)          (None, 14, 14, 1024)  263168                                       
____________________________________________________________________________________________________
bn4b_branch2c (BatchNormalizatio (None, 14, 14, 1024)  4096                                         
____________________________________________________________________________________________________
add_9 (Add)                      (None, 14, 14, 1024)  0                                            
____________________________________________________________________________________________________
activation_28 (Activation)       (None, 14, 14, 1024)  0                                            
____________________________________________________________________________________________________
res4c_branch2a (Conv2D)          (None, 14, 14, 256)   262400                                       
____________________________________________________________________________________________________
bn4c_branch2a (BatchNormalizatio (None, 14, 14, 256)   1024                                         
____________________________________________________________________________________________________
activation_29 (Activation)       (None, 14, 14, 256)   0                                            
____________________________________________________________________________________________________
res4c_branch2b (Conv2D)          (None, 14, 14, 256)   590080                                       
____________________________________________________________________________________________________
bn4c_branch2b (BatchNormalizatio (None, 14, 14, 256)   1024                                         
____________________________________________________________________________________________________
activation_30 (Activation)       (None, 14, 14, 256)   0                                            
____________________________________________________________________________________________________
res4c_branch2c (Conv2D)          (None, 14, 14, 1024)  263168                                       
____________________________________________________________________________________________________
bn4c_branch2c (BatchNormalizatio (None, 14, 14, 1024)  4096                                         
____________________________________________________________________________________________________
add_10 (Add)                     (None, 14, 14, 1024)  0                                            
____________________________________________________________________________________________________
activation_31 (Activation)       (None, 14, 14, 1024)  0                                            
____________________________________________________________________________________________________
res4d_branch2a (Conv2D)          (None, 14, 14, 256)   262400                                       
____________________________________________________________________________________________________
bn4d_branch2a (BatchNormalizatio (None, 14, 14, 256)   1024                                         
____________________________________________________________________________________________________
activation_32 (Activation)       (None, 14, 14, 256)   0                                            
____________________________________________________________________________________________________
res4d_branch2b (Conv2D)          (None, 14, 14, 256)   590080                                       
____________________________________________________________________________________________________
bn4d_branch2b (BatchNormalizatio (None, 14, 14, 256)   1024                                         
____________________________________________________________________________________________________
activation_33 (Activation)       (None, 14, 14, 256)   0                                            
____________________________________________________________________________________________________
res4d_branch2c (Conv2D)          (None, 14, 14, 1024)  263168                                       
____________________________________________________________________________________________________
bn4d_branch2c (BatchNormalizatio (None, 14, 14, 1024)  4096                                         
____________________________________________________________________________________________________
add_11 (Add)                     (None, 14, 14, 1024)  0                                            
____________________________________________________________________________________________________
activation_34 (Activation)       (None, 14, 14, 1024)  0                                            
____________________________________________________________________________________________________
res4e_branch2a (Conv2D)          (None, 14, 14, 256)   262400                                       
____________________________________________________________________________________________________
bn4e_branch2a (BatchNormalizatio (None, 14, 14, 256)   1024                                         
____________________________________________________________________________________________________
activation_35 (Activation)       (None, 14, 14, 256)   0                                            
____________________________________________________________________________________________________
res4e_branch2b (Conv2D)          (None, 14, 14, 256)   590080                                       
____________________________________________________________________________________________________
bn4e_branch2b (BatchNormalizatio (None, 14, 14, 256)   1024                                         
____________________________________________________________________________________________________
activation_36 (Activation)       (None, 14, 14, 256)   0                                            
____________________________________________________________________________________________________
res4e_branch2c (Conv2D)          (None, 14, 14, 1024)  263168                                       
____________________________________________________________________________________________________
bn4e_branch2c (BatchNormalizatio (None, 14, 14, 1024)  4096                                         
____________________________________________________________________________________________________
add_12 (Add)                     (None, 14, 14, 1024)  0                                            
____________________________________________________________________________________________________
activation_37 (Activation)       (None, 14, 14, 1024)  0                                            
____________________________________________________________________________________________________
res4f_branch2a (Conv2D)          (None, 14, 14, 256)   262400                                       
____________________________________________________________________________________________________
bn4f_branch2a (BatchNormalizatio (None, 14, 14, 256)   1024                                         
____________________________________________________________________________________________________
activation_38 (Activation)       (None, 14, 14, 256)   0                                            
____________________________________________________________________________________________________
res4f_branch2b (Conv2D)          (None, 14, 14, 256)   590080                                       
____________________________________________________________________________________________________
bn4f_branch2b (BatchNormalizatio (None, 14, 14, 256)   1024                                         
____________________________________________________________________________________________________
activation_39 (Activation)       (None, 14, 14, 256)   0                                            
____________________________________________________________________________________________________
res4f_branch2c (Conv2D)          (None, 14, 14, 1024)  263168                                       
____________________________________________________________________________________________________
bn4f_branch2c (BatchNormalizatio (None, 14, 14, 1024)  4096                                         
____________________________________________________________________________________________________
add_13 (Add)                     (None, 14, 14, 1024)  0                                            
____________________________________________________________________________________________________
activation_40 (Activation)       (None, 14, 14, 1024)  0                                            
____________________________________________________________________________________________________
res5a_branch2a (Conv2D)          (None, 7, 7, 512)     524800                                       
____________________________________________________________________________________________________
bn5a_branch2a (BatchNormalizatio (None, 7, 7, 512)     2048                                         
____________________________________________________________________________________________________
activation_41 (Activation)       (None, 7, 7, 512)     0                                            
____________________________________________________________________________________________________
res5a_branch2b (Conv2D)          (None, 7, 7, 512)     2359808                                      
____________________________________________________________________________________________________
bn5a_branch2b (BatchNormalizatio (None, 7, 7, 512)     2048                                         
____________________________________________________________________________________________________
activation_42 (Activation)       (None, 7, 7, 512)     0                                            
____________________________________________________________________________________________________
res5a_branch2c (Conv2D)          (None, 7, 7, 2048)    1050624                                      
____________________________________________________________________________________________________
res5a_branch1 (Conv2D)           (None, 7, 7, 2048)    2099200                                      
____________________________________________________________________________________________________
bn5a_branch2c (BatchNormalizatio (None, 7, 7, 2048)    8192                                         
____________________________________________________________________________________________________
bn5a_branch1 (BatchNormalization (None, 7, 7, 2048)    8192                                         
____________________________________________________________________________________________________
add_14 (Add)                     (None, 7, 7, 2048)    0                                            
____________________________________________________________________________________________________
activation_43 (Activation)       (None, 7, 7, 2048)    0                                            
____________________________________________________________________________________________________
res5b_branch2a (Conv2D)          (None, 7, 7, 512)     1049088                                      
____________________________________________________________________________________________________
bn5b_branch2a (BatchNormalizatio (None, 7, 7, 512)     2048                                         
____________________________________________________________________________________________________
activation_44 (Activation)       (None, 7, 7, 512)     0                                            
____________________________________________________________________________________________________
res5b_branch2b (Conv2D)          (None, 7, 7, 512)     2359808                                      
____________________________________________________________________________________________________
bn5b_branch2b (BatchNormalizatio (None, 7, 7, 512)     2048                                         
____________________________________________________________________________________________________
activation_45 (Activation)       (None, 7, 7, 512)     0                                            
____________________________________________________________________________________________________
res5b_branch2c (Conv2D)          (None, 7, 7, 2048)    1050624                                      
____________________________________________________________________________________________________
bn5b_branch2c (BatchNormalizatio (None, 7, 7, 2048)    8192                                         
____________________________________________________________________________________________________
add_15 (Add)                     (None, 7, 7, 2048)    0                                            
____________________________________________________________________________________________________
activation_46 (Activation)       (None, 7, 7, 2048)    0                                            
____________________________________________________________________________________________________
res5c_branch2a (Conv2D)          (None, 7, 7, 512)     1049088                                      
____________________________________________________________________________________________________
bn5c_branch2a (BatchNormalizatio (None, 7, 7, 512)     2048                                         
____________________________________________________________________________________________________
activation_47 (Activation)       (None, 7, 7, 512)     0                                            
____________________________________________________________________________________________________
res5c_branch2b (Conv2D)          (None, 7, 7, 512)     2359808                                      
____________________________________________________________________________________________________
bn5c_branch2b (BatchNormalizatio (None, 7, 7, 512)     2048                                         
____________________________________________________________________________________________________
activation_48 (Activation)       (None, 7, 7, 512)     0                                            
____________________________________________________________________________________________________
res5c_branch2c (Conv2D)          (None, 7, 7, 2048)    1050624                                      
____________________________________________________________________________________________________
bn5c_branch2c (BatchNormalizatio (None, 7, 7, 2048)    8192                                         
____________________________________________________________________________________________________
add_16 (Add)                     (None, 7, 7, 2048)    0                                            
____________________________________________________________________________________________________
activation_49 (Activation)       (None, 7, 7, 2048)    0                                            
____________________________________________________________________________________________________
avg_pool (AveragePooling2D)      (None, 1, 1, 2048)    0                                            
____________________________________________________________________________________________________
global_average_pooling2d_1 (Glob (None, 2048)          0                                            
____________________________________________________________________________________________________
batch_normalization_1 (BatchNorm (None, 2048)          8192                                         
____________________________________________________________________________________________________
dense_1 (Dense)                  (None, 512)           1049088                                      
____________________________________________________________________________________________________
dropout_1 (Dropout)              (None, 512)           0                                            
____________________________________________________________________________________________________
batch_normalization_2 (BatchNorm (None, 512)           2048                                         
____________________________________________________________________________________________________
dense_2 (Dense)                  (None, 512)           262656                                       
____________________________________________________________________________________________________
dropout_2 (Dropout)              (None, 512)           0                                            
____________________________________________________________________________________________________
dense_3 (Dense)                  (None, 2)             1026                                         
====================================================================================================
Total params: 24,910,722.0
Trainable params: 24,852,482.0
Non-trainable params: 58,240.0
____________________________________________________________________________________________________
In [16]:
from keras import optimizers 

sgd = optimizers.SGD(lr=0.0001, decay=1e-6, momentum=0.9, nesterov=True)

model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])

3. Train

In [17]:
from keras.callbacks import ModelCheckpoint, EarlyStopping

best_weights_path='SATURDAY.weights.best.hdf5'

# train the model
checkpointer = ModelCheckpoint(filepath=best_weights_path, verbose=1, save_best_only=True)

# Stop the training if the model shows no improvement 
stopper = EarlyStopping(monitor='val_loss', min_delta=0.01, patience=5, verbose=1, mode='auto')

history = model.fit(train_tensors, train_targets, epochs=50, batch_size=64,
          validation_data=(valid_tensors, valid_targets),
          callbacks=[checkpointer,stopper], verbose=1, shuffle=True)
Train on 9023 samples, validate on 1003 samples
Epoch 1/50
8960/9023 [============================>.] - ETA: 1s - loss: 0.7303 - acc: 0.6818Epoch 00000: val_loss improved from inf to 0.39954, saving model to SATURDAY.weights.best.hdf5
9023/9023 [==============================] - 307s - loss: 0.7287 - acc: 0.6825 - val_loss: 0.3995 - val_acc: 0.8405
Epoch 2/50
8960/9023 [============================>.] - ETA: 1s - loss: 0.4461 - acc: 0.8094Epoch 00001: val_loss improved from 0.39954 to 0.31141, saving model to SATURDAY.weights.best.hdf5
9023/9023 [==============================] - 291s - loss: 0.4473 - acc: 0.8085 - val_loss: 0.3114 - val_acc: 0.8804
Epoch 3/50
8960/9023 [============================>.] - ETA: 1s - loss: 0.3553 - acc: 0.8540Epoch 00002: val_loss improved from 0.31141 to 0.27027, saving model to SATURDAY.weights.best.hdf5
9023/9023 [==============================] - 291s - loss: 0.3548 - acc: 0.8544 - val_loss: 0.2703 - val_acc: 0.8933
Epoch 4/50
8960/9023 [============================>.] - ETA: 1s - loss: 0.3129 - acc: 0.8722Epoch 00003: val_loss improved from 0.27027 to 0.25385, saving model to SATURDAY.weights.best.hdf5
9023/9023 [==============================] - 291s - loss: 0.3123 - acc: 0.8724 - val_loss: 0.2539 - val_acc: 0.9073
Epoch 5/50
8960/9023 [============================>.] - ETA: 1s - loss: 0.2762 - acc: 0.8921Epoch 00004: val_loss improved from 0.25385 to 0.24132, saving model to SATURDAY.weights.best.hdf5
9023/9023 [==============================] - 291s - loss: 0.2755 - acc: 0.8924 - val_loss: 0.2413 - val_acc: 0.9063
Epoch 6/50
8960/9023 [============================>.] - ETA: 1s - loss: 0.2503 - acc: 0.9008Epoch 00005: val_loss improved from 0.24132 to 0.22891, saving model to SATURDAY.weights.best.hdf5
9023/9023 [==============================] - 291s - loss: 0.2497 - acc: 0.9010 - val_loss: 0.2289 - val_acc: 0.9093
Epoch 7/50
8960/9023 [============================>.] - ETA: 1s - loss: 0.2298 - acc: 0.9098Epoch 00006: val_loss improved from 0.22891 to 0.22673, saving model to SATURDAY.weights.best.hdf5
9023/9023 [==============================] - 291s - loss: 0.2294 - acc: 0.9100 - val_loss: 0.2267 - val_acc: 0.9153
Epoch 8/50
8960/9023 [============================>.] - ETA: 1s - loss: 0.2017 - acc: 0.9227Epoch 00007: val_loss improved from 0.22673 to 0.21092, saving model to SATURDAY.weights.best.hdf5
9023/9023 [==============================] - 291s - loss: 0.2025 - acc: 0.9222 - val_loss: 0.2109 - val_acc: 0.9202
Epoch 9/50
8960/9023 [============================>.] - ETA: 1s - loss: 0.1961 - acc: 0.9246Epoch 00008: val_loss improved from 0.21092 to 0.20619, saving model to SATURDAY.weights.best.hdf5
9023/9023 [==============================] - 291s - loss: 0.1955 - acc: 0.9247 - val_loss: 0.2062 - val_acc: 0.9212
Epoch 10/50
8960/9023 [============================>.] - ETA: 1s - loss: 0.1858 - acc: 0.9283Epoch 00009: val_loss improved from 0.20619 to 0.20058, saving model to SATURDAY.weights.best.hdf5
9023/9023 [==============================] - 291s - loss: 0.1854 - acc: 0.9285 - val_loss: 0.2006 - val_acc: 0.9232
Epoch 11/50
8960/9023 [============================>.] - ETA: 1s - loss: 0.1667 - acc: 0.9340Epoch 00010: val_loss improved from 0.20058 to 0.19528, saving model to SATURDAY.weights.best.hdf5
9023/9023 [==============================] - 291s - loss: 0.1668 - acc: 0.9339 - val_loss: 0.1953 - val_acc: 0.9302
Epoch 12/50
8960/9023 [============================>.] - ETA: 1s - loss: 0.1541 - acc: 0.9415Epoch 00011: val_loss did not improve
9023/9023 [==============================] - 289s - loss: 0.1536 - acc: 0.9418 - val_loss: 0.1977 - val_acc: 0.9242
Epoch 13/50
8960/9023 [============================>.] - ETA: 1s - loss: 0.1468 - acc: 0.9443Epoch 00012: val_loss improved from 0.19528 to 0.19499, saving model to SATURDAY.weights.best.hdf5
9023/9023 [==============================] - 291s - loss: 0.1467 - acc: 0.9443 - val_loss: 0.1950 - val_acc: 0.9292
Epoch 14/50
8960/9023 [============================>.] - ETA: 1s - loss: 0.1301 - acc: 0.9516Epoch 00013: val_loss did not improve
9023/9023 [==============================] - 289s - loss: 0.1310 - acc: 0.9515 - val_loss: 0.1984 - val_acc: 0.9332
Epoch 15/50
8960/9023 [============================>.] - ETA: 1s - loss: 0.1257 - acc: 0.9558Epoch 00014: val_loss improved from 0.19499 to 0.19202, saving model to SATURDAY.weights.best.hdf5
9023/9023 [==============================] - 291s - loss: 0.1255 - acc: 0.9558 - val_loss: 0.1920 - val_acc: 0.9312
Epoch 16/50
8960/9023 [============================>.] - ETA: 1s - loss: 0.1199 - acc: 0.9565Epoch 00015: val_loss improved from 0.19202 to 0.18853, saving model to SATURDAY.weights.best.hdf5
9023/9023 [==============================] - 291s - loss: 0.1202 - acc: 0.9563 - val_loss: 0.1885 - val_acc: 0.9312
Epoch 17/50
8960/9023 [============================>.] - ETA: 1s - loss: 0.1139 - acc: 0.9583Epoch 00016: val_loss did not improve
9023/9023 [==============================] - 289s - loss: 0.1141 - acc: 0.9582 - val_loss: 0.1890 - val_acc: 0.9352
Epoch 18/50
8960/9023 [============================>.] - ETA: 1s - loss: 0.1104 - acc: 0.9586Epoch 00017: val_loss improved from 0.18853 to 0.18369, saving model to SATURDAY.weights.best.hdf5
9023/9023 [==============================] - 291s - loss: 0.1109 - acc: 0.9584 - val_loss: 0.1837 - val_acc: 0.9362
Epoch 19/50
8960/9023 [============================>.] - ETA: 1s - loss: 0.1050 - acc: 0.9612Epoch 00018: val_loss did not improve
9023/9023 [==============================] - 289s - loss: 0.1049 - acc: 0.9612 - val_loss: 0.1902 - val_acc: 0.9362
Epoch 20/50
8960/9023 [============================>.] - ETA: 1s - loss: 0.0990 - acc: 0.9647Epoch 00019: val_loss did not improve
9023/9023 [==============================] - 289s - loss: 0.0996 - acc: 0.9644 - val_loss: 0.1869 - val_acc: 0.9372
Epoch 21/50
8960/9023 [============================>.] - ETA: 1s - loss: 0.0918 - acc: 0.9663Epoch 00020: val_loss improved from 0.18369 to 0.18318, saving model to SATURDAY.weights.best.hdf5
9023/9023 [==============================] - 291s - loss: 0.0918 - acc: 0.9663 - val_loss: 0.1832 - val_acc: 0.9382
Epoch 22/50
8960/9023 [============================>.] - ETA: 1s - loss: 0.0928 - acc: 0.9669Epoch 00021: val_loss did not improve
9023/9023 [==============================] - 289s - loss: 0.0926 - acc: 0.9669 - val_loss: 0.1859 - val_acc: 0.9352
Epoch 00021: early stopping
In [ ]:
model.load_weights(best_weights_path)
In [9]:
print('Testing Accuracy: {:.4f}'.format(*model.evaluate(test_tensors, test_targets)))
Testing Accuracy: 0.9832
In [24]:
import matplotlib.pyplot as plt

print(history.history.keys())

# summarize history for accuracy
plt.plot(history.history['acc'])
plt.plot(history.history['val_acc'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.show()
# summarize history for loss
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.show()
dict_keys(['val_acc', 'loss', 'val_loss', 'acc'])

Creating the submission file and predictions

In [25]:
submission_test, submission_targets = load_dataset('../data/here/load')
In [26]:
# save predictions into csv file 
import csv
from tqdm import tqdm_notebook

csv_name = '../predictions/continous_values_17.csv'

with open(csv_name, 'w') as f:
    csvwriter = csv.writer(f)
    csvwriter.writerow(['ID', 'class'])
    for path in tqdm_notebook(sorted(submission_test)):
        tensor = preprocess_input(path_to_tensor(path))
        pred = model.predict(tensor)[0]
        csvwriter.writerow([path, pred[0]])

In [27]:
import pandas as pd
df_submit = pd.read_csv(csv_name)
df_submit.head()
Out[27]:
ID class
0 ../data/here/load/public_test/0.png 0.002478
1 ../data/here/load/public_test/10009.png 0.998833
2 ../data/here/load/public_test/10011.png 0.996648
3 ../data/here/load/public_test/10012.png 0.999612
4 ../data/here/load/public_test/10019.png 0.857385
In [28]:
# Fixing the submission file format to discrete values 0 and 1
import csv
import re
import ntpath
import os

with open('../predictions/friday_discrete_17.csv', 'w') as f:
    csvwriter = csv.writer(f)
    csvwriter.writerow(['ID', 'class'])
    for i, row in df_submit.iterrows():
        head, tail = ntpath.split(row[0])
        idx = os.path.splitext(tail)[0]
        if row[1] > 0.5:
            classy = 0
        else:
            classy = 1
        csvwriter.writerow([idx, classy])